Conversation

@martinlsm
Contributor

@martinlsm martinlsm commented Sep 16, 2025

Certain graphs caused infinite recursion when unioning nodes to groups based on shared quantization specs while implicit sharing was enabled. The problem occurred when EdgeOrNode A with its own QuantizationSpec received an edge (in shared_with_map) to EdgeOrNode B which in turn had a SharedQuantizationSpec pointing back to A. While in this state, looking up the quantization spec in _unwrap_shared_qspec resulted in endless recursion and the program crashed.

Remedy this problem by, prior to unioning two trees, checking if the root of the parent has a SharedQuantizationSpec pointing to the root of the child; if that is the case, reverse the edge by letting the parent point to the child. This ensures that no cycle like the one described above is formed.

Add a test case for the graph that this commit fixes; namely, a graph where one input edge is shared between two ops. This graph would trigger the cyclic-reference bug that was seen prior to this patch.
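The fix described above can be sketched as a minimal union-find over a `shared_with_map` dict. This is an illustrative sketch, not the actual torchao code; `qspec_points_to` is a hypothetical stand-in for looking up where an entity's `SharedQuantizationSpec` points.

```python
def find_root(item, shared_with_map):
    """Follow parent pointers to the root of item's tree."""
    while item in shared_with_map:
        item = shared_with_map[item]
    return item

def union(child, parent, shared_with_map, qspec_points_to):
    """Union two trees, reversing the edge when it would close a cycle."""
    root_child = find_root(child, shared_with_map)
    root_parent = find_root(parent, shared_with_map)
    if root_child == root_parent:
        return  # already in the same group
    if qspec_points_to.get(root_parent) == root_child:
        # The parent's root already references the child's root via a
        # shared qspec; pointing child -> parent would form a cycle,
        # so reverse the edge instead.
        shared_with_map[root_parent] = root_child
    else:
        shared_with_map[root_child] = root_parent
```

With `qspec_points_to = {"B": "A"}` (B's qspec points back to A), `union("A", "B", ...)` makes B point to A instead of A to B, so no cycle forms.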

Certain graphs caused infinite recursion when unioning nodes to groups
based on shared quantization specs while implicit sharing was enabled.
The problem occurred when `EdgeOrNode` A with its own `QuantizationSpec`
received an edge (in `shared_with_map`) to `EdgeOrNode` B which in turn
had a `SharedQuantizationSpec` pointing back to A.

Remedy this problem by checking if B, from the scenario above, has a
`SharedQuantizationSpec` pointing to A; if that is the case, don't union
them together by letting A point back to B. Avoiding the union/cycle
preserves correctness because the nodes are effectively already united.
@pytorch-bot

pytorch-bot bot commented Sep 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3011

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b5e4ea8 with merge base 4dffb40:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 16, 2025
@martinlsm
Contributor Author

This PR is the proposed solution to the bug I initially reported here in the ExecuTorch repo: pytorch/executorch#13842

@jerryzh168
Contributor

jerryzh168 commented Sep 16, 2025

Thanks, can you add a test in https://github.com/pytorch/ao/blob/main/test/quantization/pt2e/test_quantize_pt2e.py for us to understand the issue better?

also for the test, add a comment that follows this format:

op2 -> cat1 -> cat2
will be better for people to follow as well

@facebook-github-bot
Contributor

@jerryzh168 has imported this pull request. If you are a Meta employee, you can view this in D82597005.

- Add test case of implicit sharing for a model where one input is
  shared between two different ops.
- Add code comments to `_union_if_no_cycle`
@kimishpatel
Contributor

The problem occurred when NodeOrEdge A with its own QuantizationSpec received an edge (in shared_with_map) to EdgeOrNode B

I don't follow why there is an A -> B edge if B's quant spec points to A

@martinlsm
Contributor Author

Thanks, can you add a test in https://github.com/pytorch/ao/blob/main/test/quantization/pt2e/test_quantize_pt2e.py for us to understand the issue better?

also for the test, add a comment that follows this format:

op2 -> cat1 -> cat2

will be better for people to follow as well

I've added a test now. Have a look and see what you think :)

@martinlsm
Contributor Author

The problem occurred when NodeOrEdge A with its own QuantizationSpec received an edge (in shared_with_map) to EdgeOrNode B

I don't follow why there is an A -> B edge if B's quant spec points to A

It's quite complicated, but I will attempt to explain. We have the following model:

class M(torch.nn.Module):
    def forward(self, x):
        a = x.clone()
        b = torch.eq(a, x)
        return b

And the following situation of qspecs:

  • clone
    • Input (x,clone) has its own QuantizationSpec
    • Output clone shares qspec with input: SharedQuantizationSpec((x,clone))
  • eq
    • First input (clone,eq) has its own QuantizationSpec
    • Second input (x,eq) shares spec with the first input: SharedQuantizationSpec((clone,eq))
    • Output eq is not quantized (bool output)
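The annotation above can be written out as plain data, a hedged sketch in which strings and tuples stand in for the real `QuantizationSpec` / `SharedQuantizationSpec` objects and an input edge is a `(producer, consumer)` tuple:

```python
# "own" stands in for a concrete QuantizationSpec; a non-"own" value stands
# in for SharedQuantizationSpec(target). The bool output of eq is absent
# because it is not quantized.
edge_or_node_to_qspec = {
    ("x", "clone"): "own",          # input of clone: its own spec
    "clone": ("x", "clone"),        # output shares with its input
    ("clone", "eq"): "own",         # first input of eq: its own spec
    ("x", "eq"): ("clone", "eq"),   # second input shares with the first
}

def resolve(key, table):
    """Chase shared specs until an edge/node with its own spec is found."""
    seen = set()
    while table[key] != "own":
        assert key not in seen, "cycle in shared qspecs"
        seen.add(key)
        key = table[key]
    return key
```

Here `resolve(("x", "eq"), edge_or_node_to_qspec)` lands on `("clone", "eq")`, the owner of the concrete spec.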

Then the algorithm _get_edge_or_node_to_group_id produces extra edges in shared_with_map when implicit sharing is enabled. These edges point downstream in the graph.

This is what happens when the problem occurs:

  1. The algorithm processes the first edge (x,clone): implicit sharing will produce edge shared_with_map[(x,clone)] = (x,eq) because eq is another user of x with an identical qspec.
  2. The algorithm processes the second edge (clone,eq): implicit sharing will cause (clone,eq) to be shared with clone (sharing with producer node), but clone has a SharedQuantizationSpec((x,clone)), and in step 1, we made (x,clone) share with (point to) (x,eq). So (x,eq) is the node we end up at and we produce the edge shared_with_map[(clone,eq)] = (x,eq)
  3. Now things have gone bad: when the algorithm processes the last edge (x,eq) and tries to find its qspec, it gets stuck in a loop. It looks up (clone,eq) via its shared qspec but ends up back at (x,eq) because of the edge we produced in step 2. This is the endless loop.

Maybe there's a smarter way to solve this problem but I could only come up with the solution to skip producing the edge at step 2. If we skip forming an edge in step 2, the nodes/edges will still be unioned in the end because when the last edge (x,eq) is processed, we will union it with (clone,eq) by forming the edge in the correct direction (shared_with_map[(x,eq)] = (clone,eq), and not the other way around as we did in step 2.).
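The endless loop in step 3 can be reproduced with a tiny stand-in for the real data structures. Names mirror the discussion, but this is an illustration, not the actual `_unwrap_shared_qspec` code.

```python
# State after steps 1 and 2 above. "own" stands in for a concrete
# QuantizationSpec; a tuple value stands in for a SharedQuantizationSpec.
qspec = {
    ("x", "clone"): "own",
    "clone": ("x", "clone"),
    ("clone", "eq"): "own",
    ("x", "eq"): ("clone", "eq"),
}
shared_with_map = {
    ("x", "clone"): ("x", "eq"),   # step 1
    ("clone", "eq"): ("x", "eq"),  # step 2: the problematic edge
}

def unwrap(key, max_hops=16):
    """Chase shared qspecs and union-find pointers; bail out on a cycle."""
    for _ in range(max_hops):
        spec = qspec[key]
        if spec != "own":
            key = spec                  # follow SharedQuantizationSpec
        elif key in shared_with_map:
            key = shared_with_map[key]  # follow union-find pointer
        else:
            return key
    raise RecursionError(f"cycle while resolving qspec for {key}")
```

Resolving (x,eq) bounces between (x,eq) and (clone,eq) forever, mirroring the reported crash.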

@jerryzh168
Contributor

jerryzh168 commented Sep 18, 2025

3. Now things have gone bad: when the algorithm processes the last edge (x,eq) and tries to find its qspec, it gets stuck in a loop. It looks up (clone,eq) via its shared qspec but ends up back at (x,eq) because of the edge we produced in step 2. This is the endless loop.

thanks for the detailed example! this is very helpful, could you add this explanation to the test as well?

looking at the example, seems like this is still a relatively simple case, what if it takes a few more steps for (clone, eq) to point back to (x, eq)?

It seems to me, at the high level, we need some ordering of the output nodes and input edges, so that instead of have the pointer always point from whatever edges we are processing to another node or edge, as we are doing in the implicit sharing code, we always follow the ordering instead. for example in this case, we have:

name         ordering
x            0
(x, clone)   1
clone        2
(clone, eq)  3
(x, eq)      4
eq           5

the problem right now is (clone, eq) points to (x, eq), and then later (x, eq) points to (clone, eq) again, forming a loop, due to lack of ordering. Let's say we follow the order listed above (when sharing, only higher-ordering entities point to lower-ordering entities, never the reverse). Then even if we need to union (clone, eq) (3) and (x, eq) (4) multiple times, we always union the higher with the lower, and we'll always have (x, eq) point to (clone, eq) in both steps 2 and 3, since the ordering of (x, eq) is higher than that of (clone, eq).

We could probably add a new ordering map to assign an ordering to each entity (output node and input edge) in the graph, I think.
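The ordering idea above can be sketched as follows; the `ordering` map is hypothetical, and nothing like it exists in the torchao code today.

```python
# Assign each output node / input edge a topological index; when unioning,
# always point the higher-ordered entity at the lower-ordered one.
ordering = {
    "x": 0, ("x", "clone"): 1, "clone": 2,
    ("clone", "eq"): 3, ("x", "eq"): 4, "eq": 5,
}

def ordered_union(a, b, shared_with_map):
    """Union a and b; pointers always go from higher to lower ordering."""
    lo, hi = sorted((a, b), key=lambda k: ordering[k])
    if lo != hi:
        shared_with_map[hi] = lo

shared_with_map = {}
ordered_union(("clone", "eq"), ("x", "eq"), shared_with_map)  # step 2
ordered_union(("x", "eq"), ("clone", "eq"), shared_with_map)  # step 3
# both calls point (x, eq) at (clone, eq), so no loop can form
```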

@kimishpatel
Contributor

the problem right now is (clone, eq) points to (x, eq), and then later (x, eq) points to (clone ,eq) again, forming a loop, due to lack of ordering.

I don't think that's the issue. The loop is really (x, eq) -> (clone, eq) -> clone node -> (x, clone) -> (x, eq).

@kimishpatel
Contributor

  • The algorithm processes the first edge (x,clone): implicit sharing will produce edge shared_with_map[(x,clone)] = (x,eq) because eq is another user of x with an identical qspec.

Given (x, clone) has its own quant spec, why do we end up with shared_with_map[(x, clone)] = (x, eq)? If anything, I would have expected, after processing (x, clone), shared_with_map to be empty, OR, with implicit sharing, shared_with_map[(x, eq)] = (x, clone). Basically, whichever edge is annotated to have its own qspec should never be a key in shared_with_map.

@kimishpatel
Contributor

  • The algorithm processes the first edge (x,clone): implicit sharing will produce edge shared_with_map[(x,clone)] = (x,eq) because eq is another user of x with an identical qspec.

Given (x, clone) has its own quant spec, why do we end up with shared_with_map[(x, clone)] = (x, eq)? If anything, I would have expected, after processing (x, clone), shared_with_map to be empty, OR, with implicit sharing, shared_with_map[(x, eq)] = (x, clone). Basically, whichever edge is annotated to have its own qspec should never be a key in shared_with_map.

Now, in this situation, "First input (clone,eq) has its own QuantizationSpec" becomes problematic because implicit sharing would force it to map to the clone node. So maybe some ordering of which edge owns the qspec should be done. And maybe if you do explicit cycle finding, then you would denote (x, eq) as the leader.

@jerryzh168
Contributor

  • The algorithm processes the first edge (x,clone): implicit sharing will produce edge shared_with_map[(x,clone)] = (x,eq) because eq is another user of x with an identical qspec.

Given (x, clone) has its own quant spec, why do we end up with shared_with_map[(x, clone)] = (x, eq)? If anything, I would have expected, after processing (x, clone), shared_with_map to be empty, OR, with implicit sharing, shared_with_map[(x, eq)] = (x, clone). Basically, whichever edge is annotated to have its own qspec should never be a key in shared_with_map.

Even something annotated with its own spec could share with other nodes/edges when they are the "same", e.g.

op1 -> op2
   \-> op3

The output node of op1, the input edge (op1, op2), and the input edge (op1, op3) are the same thing, since we just need to insert one observer when implicit sharing is enabled.

@jerryzh168
Contributor

the problem right now is (clone, eq) points to (x, eq), and then later (x, eq) points to (clone ,eq) again, forming a loop, due to lack of ordering.

I don't think that's the issue. The loop is really (x, eq) -> (clone, eq) -> clone node -> (x, clone) -> (x, eq).

I think (clone, eq) points to (x, eq) IIUC, but cc @martinlsm to confirm, it might be helpful to write down edge_or_node_to_qspec and the exact content of shared_with_map after each step

@jerryzh168
Contributor

my understanding is the following:

[Screenshot: 2025-09-18 at 17:01:21]

@kimishpatel
Contributor

btw, can you not detect cycles explicitly? I think the issue is that we need one leader node that takes responsibility for owning the qspec. We get into situations like "I don't have a qspec, go ask this other guy", and then we have guy x asking guy y, asking guy z, asking guy x. So it seems to me that for implicit sharing we have to denote one node in the cycle as the leader.

@martinlsm
Contributor Author

the problem right now is (clone, eq) points to (x, eq), and then later (x, eq) points to (clone ,eq) again, forming a loop, due to lack of ordering.

I don't think that's the issue. The loop is really (x, eq) -> (clone, eq) -> clone node -> (x, clone) -> (x, eq).

I agree with @kimishpatel. Even if we decide on some order by sorting, there could be shared qspecs pointing in the reverse order, and thus we could risk forming a loop, if I understand things correctly.

@martinlsm
Contributor Author

my understanding is the following:

[Screenshot: 2025-09-18 at 17:01:21]

This is close to correct, except that in step 2, when processing (clone,eq), it is not shared with (x,eq) as another user. Another user in this context would be another user of clone, of which there are none; i.e., eq is the only user of clone, and that is what we are currently processing.

(clone,eq) instead ends up pointing to (x,eq) by sharing with the previous output clone. clone, in the current state, points to (x,eq) because:

  • clone has a shared qspec to (x,clone)
  • (x,clone) points to (x,eq) in shared_with_map because we formed that edge in the previous step when we processed (x,clone).

So when (clone,eq) wants to share with clone, it arrives at (x,eq).

@martinlsm
Contributor Author

btw, can you not detect cycles explicitly? I think the issue is that we need one leader node that takes responsibility for owning the qspec. We get into situations like "I don't have a qspec, go ask this other guy", and then we have guy x asking guy y, asking guy z, asking guy x. So it seems to me that for implicit sharing we have to denote one node in the cycle as the leader.

I just submitted a new version that is kind of what you suggest here, I guess. Instead of skipping forming the union, I instead reverse the edge in _union. This feels a bit more robust since we still ensure that the nodes are unioned, and there's less risk that we'll miss some edge case, I guess...

looking at the example, seems like this is still a relatively simple case, what if it takes a few more steps for (clone, eq) to point back to (x, eq)?

Regarding this good point, I don't believe it can take more than one step, because _union calls _find_root_edge_or_node on both the parent's and the child's subtree, and that function performs path compression, so any indirections will be eliminated. Please correct me if I'm wrong here, though, because this stuff is complicated!
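The path compression being referred to can be sketched as follows, as a stand-in for `_find_root_edge_or_node`, assuming the real function behaves like classic union-find find-with-compression:

```python
def find_root(item, shared_with_map):
    """Find the root of item's tree and flatten the path to it."""
    root = item
    while root in shared_with_map:
        root = shared_with_map[root]
    # Path compression: repoint every entry on the path directly at the
    # root, so no multi-step indirection survives a lookup.
    while item != root:
        nxt = shared_with_map[item]
        shared_with_map[item] = root
        item = nxt
    return root
```

After `find_root("a", m)` on a chain a -> b -> c -> d, all of a, b, and c point directly at d.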

@martinlsm
Contributor Author

thanks for the detailed example! this is very helpful, could you add this explanation to the test as well?

Oh, right I forgot about this one. Will sort that out on Monday. But in the meantime, let me know what you think about the new solution if you got the time.

@jerryzh168
Contributor

the problem right now is (clone, eq) points to (x, eq), and then later (x, eq) points to (clone ,eq) again, forming a loop, due to lack of ordering.

I don't think that's the issue. The loop is really (x, eq) -> (clone, eq) -> clone node -> (x, clone) -> (x, eq).

I agree with @kimishpatel. Even if we decide on some order by sorting, there could be shared qspecs pointing in the reverse order, and thus we could risk forming a loop, if I understand things correctly.

I think it's similar: if we have an ordering, then even if the user specifies sharing in the reverse direction, we'll have the node/edge with the lower ordering_id hold the final qspec, so we will have to modify the edge_or_node_to_qspec dictionary as well.

personally I feel the ordering map will make things easier, since there's no need to worry about cycle detection anymore, but both work, I feel

Comment on lines +164 to +166
# Parent already references child with a shared qspec. We would create
# a cycle if we formed an edge from the child to the parent. Therefore,
# we reverse the edge in this particular case.
Contributor

will this be more confusing than just assigning an ordering beforehand?

@jerryzh168 jerryzh168 Sep 22, 2025

also, I haven't thought it through, but I'm wondering if it's possible that root_child can go around and end up pointing to root_parent again

@jerryzh168 jerryzh168 Sep 24, 2025

cc @kimishpatel is this the cycle detection you have in mind?

seems OK to me, if this is the only thing that's needed

Contributor

I am not sure. Isn't the cycle already formed (parent_qspec.edge_or_node == root_child) before we come here? It feels like we are detecting that and correcting it. I might be wrong though.

Contributor Author

@kimishpatel Prior to forming the problematic union, we are in the state shown in the figure below, about to assign shared_with_map[(clone,eq)] = (x,eq). So the cycle has not already formed at that point; we detect it just before it would form. We are just reversing the green edge by assigning shared_with_map[(x,eq)] = (clone,eq) instead, to make the edge point in the same direction as the blue one (edge_or_node_to_qspec).

[Figure: prepare_state diagram]

@kimishpatel
Contributor

the problem right now is (clone, eq) points to (x, eq), and then later (x, eq) points to (clone ,eq) again, forming a loop, due to lack of ordering.

I don't think that's the issue. The loop is really (x, eq) -> (clone, eq) -> clone node -> (x, clone) -> (x, eq).

I agree with @kimishpatel. Even if we decide on some order by sorting, there could be shared qspecs pointing in the reverse order, and thus we could risk forming a loop, if I understand things correctly.

I think it's similar: if we have an ordering, then even if the user specifies sharing in the reverse direction, we'll have the node/edge with the lower ordering_id hold the final qspec, so we will have to modify the edge_or_node_to_qspec dictionary as well.

personally I feel the ordering map will make things easier, since there's no need to worry about cycle detection anymore, but both work, I feel

I think cycle detection is probably easier to reason about; it is clear in its intent. Unless there is a significant concern on another aspect, I would suggest we do cycle detection.

@martinlsm
Contributor Author

thanks for the detailed example! this is very helpful, could you add this explanation to the test as well?

Oh, right I forgot about this one. Will sort that out on Monday. But in the meantime, let me know what you think about the new solution if you got the time.

I have done an update regarding this now.

@jerryzh168 jerryzh168 left a comment

LGTM, seems like it's the only change that's needed to prevent a loop?

@kimishpatel kimishpatel left a comment

ok looks good

@martinlsm
Contributor Author

LGTM, seems like it's the only change that's needed to prevent a loop?

I think it's the only thing needed, yes! Proving it would be hard, but I have not yet found a case where it does not work. If this problem is ever seen again, it would just have to be revisited, I guess.

So we are good to go and can merge this? Thanks guys for the useful feedback and discussion in this PR!

@jerryzh168 jerryzh168 added topic: bug fix Use this tag for PRs that fix bugs pt2e_quant pt2 export quantization (prepare_pt2e, convert_pt2e) labels Sep 26, 2025
@jerryzh168 jerryzh168 merged commit 0b96757 into pytorch:main Sep 27, 2025
19 of 20 checks passed